%% Implement RWMH posterior sampler and form predictive distribution for validation study estimates

function [pred_intervals, coverage_rate, percentiles, nu_draws, accept_rate_all] = rwmh_posterior_sampler(hat_theta, rep_base_se2, hat_vartheta, rep_val_se2, alpha, num_samples, prior_mean, seed)
% RWMH_POSTERIOR_SAMPLER - Predictive check of validation estimates via RWMH
%
% For each row i: draws nu_i with a random-walk Metropolis-Hastings sampler
% (sample_nu_i), draws tau_i given each nu_i draw (sample_tau_i), then draws a
% predictive value for the validation estimate (sample_theta_pred). The first
% round(0.2*num_samples) draws of each chain are discarded as burn-in.
%
% INPUT:
%   hat_theta     - N x J matrix of baseline estimates (row i is used for unit i)
%   rep_base_se2  - N x J matrix of baseline variances, same layout as hat_theta
%   hat_vartheta  - N-element vector of validation estimates (row or column)
%   rep_val_se2   - N-element vector of validation variances sigma^2_{i,J+1}
%   alpha         - significance level; intervals use quantiles alpha/2, 1-alpha/2
%   num_samples   - total MCMC iterations per row, including burn-in
%   prior_mean    - prior mean passed to sample_nu_i
%   seed          - seed passed to rng so results are reproducible
%
% OUTPUT:
%   pred_intervals  - N x 2 matrix of (1-alpha) predictive intervals [lower, upper]
%                     for hat_theta_{i,J+1}
%   coverage_rate   - scalar fraction of hat_vartheta(i) lying inside interval i
%   percentiles     - N x 1 fraction of predictive draws strictly below hat_vartheta(i)
%   nu_draws        - N x (num_samples - burn-in) matrix of retained nu_i draws
%   accept_rate_all - N x 1 post-burn-in acceptance rates reported by sample_nu_i

% Initialization
N = size(hat_theta, 1);
burnin_rate = 0.2; % 20% burn-in; sample_nu_i uses the same rate for accept_rate
burnin = round(burnin_rate * num_samples);
kept_samples = num_samples - burnin;
rng(seed);

% Force column orientation: comparing a row vector against the N x 1 interval
% columns below would otherwise implicitly expand into an N x N matrix.
hat_vartheta = hat_vartheta(:);

% Storage for retained (post-burn-in) draws
posterior_samples = zeros(N, kept_samples);
nu_draws = zeros(N, kept_samples);
accept_rate_all = zeros(N, 1);

for i = 1:N
    theta_i = hat_theta(i,:)';
    sigma2_i = rep_base_se2(i,:)';

    % RWMH chain for nu_i, then tau_i | nu_i, then the predictive draw
    [nu_samples, accept_rate] = sample_nu_i(theta_i, sigma2_i, prior_mean, num_samples);
    tau_samples = sample_tau_i(theta_i, sigma2_i, nu_samples);
    theta_pred_samples = sample_theta_pred(tau_samples, nu_samples, rep_val_se2(i));

    % Keep only draws after burn-in
    posterior_samples(i,:) = theta_pred_samples(burnin+1:end);
    nu_draws(i,:) = nu_samples(burnin+1:end);
    accept_rate_all(i) = accept_rate;
end

% Predictive intervals: both quantiles of every row in one call -> N x 2
pred_intervals = quantile(posterior_samples, [alpha/2, 1 - alpha/2], 2);

% Percentile of hat_vartheta(i) within its own predictive distribution
percentiles = mean(bsxfun(@lt, posterior_samples, hat_vartheta), 2);

% Empirical coverage of the predictive intervals
in_interval = (hat_vartheta >= pred_intervals(:,1)) & (hat_vartheta <= pred_intervals(:,2));
coverage_rate = mean(in_interval);

end

function [nu_samples, accept_rate] = sample_nu_i(theta_i, sigma2_i, prior_mean, num_samples)
% SAMPLE_NU_I - Random-walk Metropolis-Hastings sampler for nu_i
%
% Target (up to a constant): log_ig_pdf(nu, 3, 2*prior_mean)
%                            + log_likelihood(theta_i, sigma2_i, nu).
% Proposal: Gaussian random walk with fixed standard deviation
% sqrt(c*prior_mean^2), c = 1. No adaptive tuning is performed.
%
% Inputs:
%   theta_i       - J x 1 vector of theta estimates
%   sigma2_i      - J x 1 vector of variance estimates
%   prior_mean    - prior mean for nu_i; also sets the proposal scale
%   num_samples   - number of MCMC iterations (the chain includes its start value)
%
% Outputs:
%   nu_samples    - num_samples x 1 vector of MCMC draws of nu_i
%   accept_rate   - accepted proposals at iterations s > round(0.2*num_samples),
%                   divided by (num_samples - round(0.2*num_samples))

nu_samples = zeros(num_samples, 1);

% Start from a draw of the Inverse Gamma(3, 2*prior_mean) prior:
% if X ~ Gamma(shape 3, scale 1/(2*prior_mean)) then 1/X ~ IG(3, 2*prior_mean).
nu_samples(1) = 1/gamrnd(3, 1/(2*prior_mean));

% Fixed random-walk step size (hoisted: it does not change inside the loop)
c = 1;
proposal_sd = sqrt(c * prior_mean^2);

accepted = 0;
burnin_rate = 0.2; % 20% burn-in; matches the burn-in discarded by the caller
burnin = round(burnin_rate * num_samples);

% Cache the log target at the current state, so each iteration evaluates the
% prior and likelihood only once (at the proposal).
nu_current = nu_samples(1);
log_target_current = log_ig_pdf(nu_current, 3, 2*prior_mean) + ...
                     log_likelihood(theta_i, sigma2_i, nu_current);

for s = 2:num_samples
    % Propose a new value from a normal centered at the current value
    nu_proposed = normrnd(nu_current, proposal_sd);

    % The target density is zero for nu <= 0, so such proposals are rejected
    % outright. The symmetric proposal needs no Hastings correction.
    if nu_proposed > 0
        log_target_proposed = log_ig_pdf(nu_proposed, 3, 2*prior_mean) + ...
                              log_likelihood(theta_i, sigma2_i, nu_proposed);

        % Accept/reject on the log scale. This avoids exp overflow, and a NaN
        % log-ratio compares false (reject), whereas min(1, NaN) returns 1
        % (accept).
        if log(rand) < log_target_proposed - log_target_current
            nu_current = nu_proposed;
            log_target_current = log_target_proposed;
            if s > burnin
                accepted = accepted + 1;
            end
        end
    end

    nu_samples(s) = nu_current;
end

accept_rate = accepted/(num_samples - burnin);

end

function log_val = log_ig_pdf(x, a, b)
% LOG_IG_PDF - Log PDF of the Inverse Gamma distribution, element-wise
%
%   log p(x | a, b) = a*log(b) - gammaln(a) - (a+1)*log(x) - b/x   for x > 0
%
% Inputs:
%   x - point(s) to evaluate; scalar or array
%   a - shape parameter (scalar)
%   b - scale parameter (scalar)
%
% Outputs:
%   log_val - log of the PDF value, same size as x; -Inf where x <= 0
%             (outside the support). NaN entries of x give NaN.

log_val = -Inf(size(x));

% ~(x <= 0) rather than (x > 0) so NaN inputs fall through to the formula
% and propagate as NaN, matching the scalar behavior.
in_support = ~(x <= 0);
if any(in_support(:))
    xs = x(in_support);
    log_val(in_support) = a * log(b) - gammaln(a) - (a + 1) * log(xs) - b ./ xs;
end

end

function log_lik = log_likelihood(theta_i, sigma2_i, nu_i)
% LOG_LIKELIHOOD - Calculates log likelihood for nu_i
%
% With v = nu_i + sigma2_i, precision P = sum(1./v) and weighted mean
% m = sum(theta_i./v)/P, returns
%
%   -J/2*log(2*pi) - 1/2*sum(log v) - 1/2*sum((theta_i - m).^2 ./ v) - 1/2*log(P)
%
% NOTE(review): this appears to be the Gaussian likelihood theta_ij ~ N(tau_i, v_j)
% with tau_i integrated out under a flat prior, omitting the constant
% +log(2*pi)/2 (which cancels in MH ratios). Confirm against the model spec.
%
% Inputs:
%   theta_i  - J x 1 vector of theta estimates for subject i
%   sigma2_i - J x 1 vector of variance estimates for subject i
%   nu_i     - scalar value of nu_i
%
% Outputs:
%   log_lik  - log likelihood value

J = length(theta_i);
v = nu_i + sigma2_i;
w = 1 ./ v;
precision = sum(w);
weighted_mean = sum(w .* theta_i) / precision;

% Normalizing terms of the J independent normals
log_norm = -0.5 * J * log(2*pi) - 0.5 * sum(log(v));

% Quadratic form taken around the weighted mean. In exact arithmetic this
% equals -0.5*(sum(theta.^2./v) - (sum(theta./v))^2/P), but it avoids the
% catastrophic cancellation of that expanded form when the estimates are
% large relative to their spread.
quad = -0.5 * sum(w .* (theta_i - weighted_mean).^2);

% Term left over from integrating tau_i out
log_det = -0.5 * log(precision);

log_lik = log_norm + quad + log_det;

end

function tau_samples = sample_tau_i(theta_i, sigma2_i, nu_samples)
% SAMPLE_TAU_I - Draw tau_i from its conditional posterior for each nu_i draw
%
% For draw k, with weights w = 1 ./ (nu_samples(k) + sigma2_i):
%   tau_i | nu_samples(k) ~ N( sum(w .* theta_i) / sum(w), 1 / sum(w) )
%
% Inputs:
%   theta_i     - J x 1 vector of theta estimates for subject i
%   sigma2_i    - J x 1 vector of variance estimates for subject i
%   nu_samples  - S x 1 vector of nu_i samples
% Outputs:
%   tau_samples - S x 1 vector of tau_i samples, one per nu_i sample

num_draws = length(nu_samples);

% Phase 1: conditional posterior moments for every nu draw
post_mean = zeros(num_draws, 1);
post_sd = zeros(num_draws, 1);
for k = 1:num_draws
    w = 1 ./ (nu_samples(k) + sigma2_i);
    post_var = 1 / sum(w);
    post_mean(k) = post_var * (sum(w .* theta_i));
    post_sd(k) = sqrt(post_var);
end

% Phase 2: one vectorized draw. Element k uses the k-th standard normal from
% the global stream, in the same order as num_draws scalar draws.
tau_samples = normrnd(post_mean, post_sd);

end

function theta_pred_samples = sample_theta_pred(tau_samples, nu_samples, sigma2_val)
% SAMPLE_THETA_PRED - Draw from the predictive distribution of the validation estimate
%
% Draw s is normal with mean tau_samples(s) and variance
% nu_samples(s) + sigma2_val.
%
% Inputs:
%   tau_samples - S x 1 vector of tau_i samples
%   nu_samples  - S x 1 vector of nu_i samples
%   sigma2_val  - scalar validation variance sigma^2_{i,J+1}
%
% Outputs:
%   theta_pred_samples - S x 1 vector of predictive samples, S = length(tau_samples)

S = length(tau_samples);
centers = reshape(tau_samples(1:S), S, 1);
spreads = sqrt(reshape(nu_samples(1:S), S, 1) + sigma2_val);

% Single vectorized draw. Element s uses the s-th standard normal from the
% global stream, exactly as S successive scalar draws would.
theta_pred_samples = normrnd(centers, spreads);

end
